import os
from operator import itemgetter

import dotenv
import weaviate
from langchain_community.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_weaviate import WeaviateVectorStore
from weaviate.auth import AuthApiKey

# Load environment variables (API keys, endpoints) from a local .env file.
dotenv.load_dotenv()

def format_qa_pairs(question: str, answer: str) -> str:
    """Format one question/answer pair as a single readable string.

    The original appended a trailing blank-line suffix ("\\n\\n") and then
    immediately removed it with ``.strip()`` — that dead suffix is dropped
    here; the returned value is unchanged.

    Args:
        question: The (sub-)question that was asked.
        answer: The answer produced for that question.

    Returns:
        ``"Question: <question>\\nAnswer:<answer>"`` with surrounding
        whitespace stripped.
    """
    return f"Question: {question}\nAnswer:{answer}".strip()

# 1. Prompt that decomposes the user's question into independent sub-questions.
_DECOMPOSITION_TEMPLATE = (
    "你是一个乐于助人的AI助理，可以针对一个输入的问题生成多个相关的子问题。\n"
    "目标是将输入的问题分解成一组可以独立回答的子问题或者子任务。\n"
    "生成与问题相关的多个搜索查询：{question}\n"
    "并使用换行符进行分割，输出（3个子问题/子查询）："
)
decomposition_prompt = ChatPromptTemplate.from_template(_DECOMPOSITION_TEMPLATE)

# 2. Chain: question -> decomposition prompt -> LLM -> list of sub-questions.
# ROBUSTNESS FIX: the raw LLM output is split on newlines; blank or
# whitespace-only lines are now dropped (and each line stripped) so the
# downstream QA loop never receives an empty sub-question.
decomposition_chain = (
    {"question": RunnablePassthrough()}
    | decomposition_prompt
    | ChatOpenAI(model_name="kimi-k2-0711-preview", temperature=0)
    | StrOutputParser()
    | (lambda text: [line.strip() for line in text.strip().split("\n") if line.strip()])
)

# 3. Build the vector store and retriever.
# Connect to Weaviate Cloud (v4 `connect_to_weaviate_cloud` API).
# SECURITY: the API key belongs in the environment (.env, already loaded
# above), not in source control. The inline literal is kept only as a
# backward-compatible fallback — rotate this credential and remove it.
client = weaviate.connect_to_weaviate_cloud(
    cluster_url="https://zabwh0mbt4errmvpknamq.c0.asia-southeast1.gcp.weaviate.cloud",
    auth_credentials=AuthApiKey(
        os.getenv(
            "WEAVIATE_API_KEY",
            "b2o4OGQxcmptMTZEWmJ5VV9udE5xSXBzQW04dUlDZ0JSS0d1ay9FQlhXdEtyMDR4OUFVNzc0eG9mU3dnPV92MjAw",
        )
    ),
)

# Vector store backed by the "myleane" Weaviate collection, embedding
# documents with Baidu Qianfan.
embedding_model = QianfanEmbeddingsEndpoint()
db = WeaviateVectorStore(
    client=client,
    index_name="myleane",
    text_key="text",
    embedding=embedding_model,
)

# MMR search balances relevance against diversity of the returned documents.
retriever = db.as_retriever(search_type="mmr")

# 4. Run the decomposition chain once to obtain the list of sub-questions
# for the original user question.
question = "关于LLMOps应用配置的文档有哪些"
sub_questions = decomposition_chain.invoke(question)

# 5. Iterative QA prompt: answer one sub-question given the Q&A pairs
# accumulated so far plus freshly retrieved context.
_ANSWER_TEMPLATE = """这是你需要回答的问题
---
{question}
---

这是所有可用的背景问题和答案对：
---
{qa_pairs}
---

这是与问题相关的额外背景信息：
---
{context}
---"""
prompt = ChatPromptTemplate.from_template(_ANSWER_TEMPLATE)

# Per-sub-question answering chain: the question and accumulated Q&A pairs
# pass straight through from the input dict, while the question is also
# piped into the retriever to fetch supporting context.
_answer_llm = ChatOpenAI(model_name="kimi-k2-0711-preview", temperature=0)
chain = (
    {
        "question": itemgetter("question"),
        "qa_pairs": itemgetter("qa_pairs"),
        "context": itemgetter("question") | retriever,
    }
    | prompt
    | _answer_llm
    | StrOutputParser()
)


# 6.循环遍历所有的子问题进行搜索并获取答案
qa_pairs = ""
for sub_question in sub_questions:
    answer = chain.invoke({"question": sub_question, "qa_pairs":qa_pairs})
    qa_pairs = format_qa_pairs(sub_question, answer)
    qa_pairs += '\n-----\n' + qa_pairs
    print(f"问题：{sub_question}")
    print(f"答案: {answer}")